# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utils module."""
from __future__ import absolute_import

import os
import json
import subprocess
import tempfile
from pathlib import Path

from datetime import datetime
from typing import Literal, Any

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.shapes import Unassigned
from sagemaker.train import logger


def _default_bucket_and_prefix(session: Session) -> str:
    """Build the default bucket location for a session, including its prefix when set.

    The result has one of two shapes:
    * ``default_bucket/default_bucket_prefix`` when the session has a prefix
    * ``default_bucket`` when ``session.default_bucket_prefix`` is ``None``

    Args:
        session (Session): The SageMaker session whose default bucket is used

    Returns:
        str: The bucket name, followed by ``/<prefix>`` when a prefix is set
    """
    has_prefix = session.default_bucket_prefix is not None
    if not has_prefix:
        return session.default_bucket()

    bucket = session.default_bucket()
    return f"{bucket}/{session.default_bucket_prefix}"


def _default_s3_uri(session: Session, additional_path: str = "") -> str:
    """Build the default ``s3://`` URI for a SageMaker session.

    The base URI is ``s3://`` followed by the session's default bucket and,
    when one is set, the default bucket prefix. If ``additional_path`` is
    non-empty after stripping its leading slashes, it is appended to the base
    URI with a single ``/`` separator.

    Args:
        session (Session): The SageMaker session to use
        additional_path (str): Extra path to append to the base URI. Leading
            slashes are removed first. Defaults to "".

    Returns:
        str: The default S3 URI, with ``additional_path`` appended if given
    """
    base_uri = f"s3://{_default_bucket_and_prefix(session)}"

    # Leading slashes are dropped so the join below never produces "//".
    suffix = additional_path.lstrip("/")
    if not suffix:
        return base_uri
    return f"{base_uri}/{suffix}"


def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
    """Tell whether ``path`` looks like an S3 URI of the requested kind.

    The check is purely syntactic: the path must start with ``s3://`` and,
    depending on ``path_type``, must or must not end with a slash. Nothing is
    looked up in S3, so the bucket or object does not need to exist.

    Args:
        path (str): S3 URI to validate
        path_type (Literal["File", "Directory", "Any"]): Kind of path expected.
            "File" requires no trailing slash, "Directory" requires a trailing
            slash, "Any" accepts both. Defaults to "Any".

    Returns:
        bool: True if the path is an S3 URI of the requested kind, False
            otherwise (including when ``path_type`` is not a recognized value)
    """
    if not path.startswith("s3://"):
        return False

    ends_with_slash = path.endswith("/")
    if path_type == "File":
        # Object keys that denote files carry no trailing slash.
        return not ends_with_slash
    if path_type == "Directory":
        # Directory-style prefixes are written with a trailing slash.
        return ends_with_slash

    # Any value other than the three supported ones is rejected.
    return path_type == "Any"


def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
    """Tell whether ``path`` exists locally and is of the requested kind.

    Args:
        path (str): Local path to validate
        path_type (Literal["File", "Directory", "Any"]): Kind of path expected.
            "File" requires a regular file, "Directory" requires a directory,
            "Any" accepts anything that exists. Defaults to "Any".

    Returns:
        bool: True if the path exists and matches ``path_type``, False
            otherwise (including when ``path_type`` is not a recognized value)
    """
    if os.path.exists(path):
        if path_type == "File":
            return os.path.isfile(path)
        if path_type == "Directory":
            return os.path.isdir(path)
        # Any value other than the three supported ones is rejected.
        return path_type == "Any"

    # Nothing exists at this location.
    return False


def _get_unique_name(base, max_length=63):
    """Generate a unique name based on the base name.

    Underscores in the base name are replaced with hyphens and a
    ``-YYYYmmddHHMMSS`` suffix built from the current local time is appended.

    When the result would exceed ``max_length``, the *base* is shortened so
    that the timestamp suffix is always kept intact; cutting characters off
    the end instead would remove the very part that makes the name unique.

    Args:
        base (str): The base name to use
        max_length (int): The maximum length of the unique name. Defaults to 63.

    Returns:
        str: The unique name, at most ``max_length`` characters long
    """
    timestamp_suffix = "-" + datetime.now().strftime("%Y%m%d%H%M%S")
    base = base.replace("_", "-")

    room_for_base = max_length - len(timestamp_suffix)
    if room_for_base <= 0:
        # max_length cannot hold the suffix plus at least one base character;
        # fall back to plain truncation so the length limit is still honored.
        return f"{base}{timestamp_suffix}"[:max_length]

    return f"{base[:room_for_base]}{timestamp_suffix}"


def _get_repo_name_from_image(image: str) -> str:
    """Extract the repository name from an image URI.

    The repository name is the last ``/``-separated segment of the URI with
    any ``:tag`` and ``@digest`` part removed.

    Example:
    ``` python
    _get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest")
    # Returns "my-repo"
    ```

    Args:
        image (str): The image URI

    Returns:
        str: The repository name
    """
    # Everything after the final "/" (or the whole string if there is none).
    last_segment = image.rpartition("/")[2]
    # Cut at the first ":" to drop a tag (and the hash part of a digest).
    without_tag = last_segment.partition(":")[0]
    # Cut at the first "@" to drop what is left of a digest reference.
    return without_tag.partition("@")[0]


def convert_unassigned_to_none(instance) -> Any:
    """Replace every ``Unassigned`` attribute value on ``instance`` with ``None``.

    The attributes are read from ``instance.__dict__`` and overwritten in
    place with ``setattr``; attributes holding any other value are left alone.

    Args:
        instance: Object whose instance attributes should be normalized.

    Returns:
        Any: The same ``instance`` object, after modification.
    """
    for attr_name, attr_value in instance.__dict__.items():
        if not isinstance(attr_value, Unassigned):
            continue
        setattr(instance, attr_name, None)
    return instance


def safe_serialize(data):
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    3. If `data` cannot be serialized (e.g., a custom object, or a container that
       contains itself), it returns the string representation of the data using
       `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation of the input.
    """
    if isinstance(data, str):
        return data
    try:
        return json.dumps(data)
    except (TypeError, ValueError):
        # TypeError: a value of an unsupported type was encountered.
        # ValueError: json.dumps detected a circular reference.
        return str(data)


def _run_clone_command_silent(repo_url, dest_dir):
    """Run the 'git clone' command with the repo url and the directory to clone the repo into.

    The clone runs non-interactively and with git's stdout and stderr discarded:

    * ``https://`` URLs are cloned with ``GIT_TERMINAL_PROMPT=0`` set in the
      environment passed to git.
    * ``git@`` and ``ssh://`` URLs are cloned with ``GIT_SSH`` pointing at a
      temporary wrapper script that invokes ``ssh -oBatchMode=yes``.
    * URLs with any other prefix are ignored: nothing is cloned and no error
      is raised.

    Args:
        repo_url (str): Git repo url to be cloned.
        dest_dir: (str): Local path where the repo should be cloned into.

    Raises:
        CalledProcessError: If failed to clone git repo.
    """
    # Work on a copy so the current process environment is never modified.
    my_env = os.environ.copy()
    clone_command = ["git", "clone", repo_url, dest_dir]

    try:
        if repo_url.startswith("https://"):
            my_env["GIT_TERMINAL_PROMPT"] = "0"
            subprocess.check_call(
                clone_command,
                env=my_env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        elif repo_url.startswith(("git@", "ssh://")):
            # The wrapper script only has to live for the duration of the clone,
            # so it is written into a directory that is removed afterwards.
            with tempfile.TemporaryDirectory() as tmp_dir:
                custom_ssh_executable = Path(tmp_dir) / "ssh_batch"
                with open(custom_ssh_executable, "w") as pipe:
                    print("#!/bin/sh", file=pipe)
                    # "$@" must be quoted so each argument git passes is
                    # forwarded to ssh unchanged (no re-splitting or globbing).
                    print('ssh -oBatchMode=yes "$@"', file=pipe)
                os.chmod(custom_ssh_executable, 0o511)
                my_env["GIT_SSH"] = str(custom_ssh_executable)
                subprocess.check_call(
                    clone_command,
                    env=my_env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to clone repository: {repo_url}")
        logger.error(f"Error output:\n{e}")
        raise

def _get_studio_tags(model_id: str, hub_name: str):
    """Build the two ``sagemaker-studio:jumpstart-*`` tags for a model and hub.

    Args:
        model_id (str): Value for the ``sagemaker-studio:jumpstart-model-id`` tag.
        hub_name (str): Value for the ``sagemaker-studio:jumpstart-hub-name`` tag.

    Returns:
        list: Two dicts, each with a ``"key"`` and a ``"value"`` entry; the
            model-id tag comes first, followed by the hub-name tag.
    """
    tag_pairs = (
        ("sagemaker-studio:jumpstart-model-id", model_id),
        ("sagemaker-studio:jumpstart-hub-name", hub_name),
    )
    return [{"key": tag_key, "value": tag_value} for tag_key, tag_value in tag_pairs]
